from sklearn.preprocessing import MinMaxScaler
import numpy as np
from pyDOE import lhs
import forward
import os
import tough_py
import pickle
import matplotlib.pyplot as plt
import shutil
import time
from matplotlib.colors import ListedColormap
from scipy import stats
import pandas as pd

# Wall-clock timestamp taken when the script starts; the elapsed-time report
# near the end of the script subtracts this from a second time.time() reading.
start_time = time.time()  
def cross_cov(X, Y):
    """Return the sample cross-covariance between the columns of X and Y.

    Both inputs are 2-D with one sample per row and one variable per column.
    Element ``[i, j]`` of the result is the unbiased (N - 1 normalised)
    covariance between column ``i`` of ``X`` and column ``j`` of ``Y``; this
    is the same block that ``np.cov`` of the column-wise concatenation
    ``[X, Y]`` would hold at ``[0:m, m:m + n]``.

    Parameters
    ----------
    X : array_like, shape (p, m)
    Y : array_like, shape (p, n)

    Returns
    -------
    numpy.ndarray, shape (m, n)

    Raises
    ------
    ValueError
        If either input is not 2-D or the row counts differ.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError("cross_cov expects two 2-D arrays (rows = samples)")
    if X.shape[0] != Y.shape[0]:
        raise ValueError("X and Y must have the same number of rows")

    # Work in at least float64, as np.cov does.
    dtype = np.result_type(X.dtype, Y.dtype, np.float64)
    x_dev = X.astype(dtype) - X.mean(axis=0, dtype=dtype)
    y_dev = Y.astype(dtype) - Y.mean(axis=0, dtype=dtype)

    # Only the (m, n) cross block is formed; the full (m + n) x (m + n)
    # covariance is never needed. conj() is a no-op for real input and keeps
    # the np.cov convention for complex input.
    # With fewer than two samples the divisor is 0 and the result is NaN,
    # which is also what np.cov returns in that case.
    dof = max(X.shape[0] - 1, 0)
    return np.dot(x_dev.T, y_dev.conj()) / dof

def create_folder(folder_path):
    """Make ``folder_path`` an empty directory, discarding previous content.

    If the path already exists it is removed recursively first; the directory
    (including any missing parents) is then created. A message is printed for
    each of the two steps.
    """
    stale = os.path.exists(folder_path)
    if stale:
        # Start from a clean slate: drop everything left by an earlier run.
        shutil.rmtree(folder_path)
        print(f"delete old files: {folder_path}")

    os.makedirs(folder_path)
    print(f"creat new files: {folder_path}")

# Output directories of this run. Each one is wiped and recreated, in this
# order, so results of a previous run never mix with the new ones.
folder_path1 = './Mengdi/seismicity'
folder_path2 = './Mengdi/tracer'
folder_path3 = './Mengdi/initial_data'
folder_path4 = './Mengdi/storage'
for _output_dir in (folder_path1, folder_path2, folder_path3, folder_path4):
    create_folder(_output_dir)

# ---- Inversion settings ----
iter_num = 20           # upper bound on the number of update iterations
inflate = 20            # inflation factor for the observation-error variance

dimension = 78          # parameters per ensemble member
fracturepoints = 2500   # number of kernel data points
tracerpoints = 61       # number of tracer data points
enlarge = 8e4           # scale factor applied to the kernel data

# ---- Observations ----
# Both files are tab separated; the values are read as text and then converted.
O1 = np.loadtxt('kernel38.txt', delimiter='\t', dtype=str, skiprows=0).astype('float64') * enlarge
O2 = np.loadtxt('tracer38.txt', delimiter='\t', dtype=str, skiprows=0).astype('float64')

# Per-parameter bounds (not referenced by the update loop below).
ub = np.full(dimension, 0.01)
lb = np.full(dimension, 0.001)

# ---- Ensemble ----
NPar = 50                               # ensemble size
NVar = dimension
NMea = tracerpoints + fracturepoints    # total number of measurements
Tolerance = 1e-4                        # stop once the mean |mismatch| falls below this

# Initial ensemble: standard-normal draws, one row per member; stored for reference.
x = np.random.randn(NPar, dimension)
np.savetxt(r'./Mengdi/initial_data/x78.txt', x, fmt='%.5E')

# ---- Loop state and result containers ----
vhat = 0.1      # current mean absolute mismatch (start value only lets the loop begin)
k = 0           # iteration counter
Tra1, Tra = [], []
EstimationX, Particle_traceY = [], []
xall, yall = [], []
saveX, savebioty1, savey = [], [], []

# Base error variance; multiplied back by `inflate` wherever it is used.
C_e = (0.05 ** 2) / inflate

# Observation-error covariance used in the update: inflate * C_e on the
# diagonal. It is identical in every iteration, so it is built once.
Rt = inflate * (C_e) * np.eye(NMea)

# Iterate until the mean absolute mismatch drops below Tolerance or the
# iteration budget is exhausted (logical `and`, not bitwise `&`).
while vhat >= Tolerance and k < iter_num:
    # Map the ensemble to forward-model input and flatten to one row per member.
    # NOTE(review): what kaidukringing produces is defined in forward.py.
    X = forward.kaidukringing(x)
    X = X.reshape((X.shape[0], -1))
    saveX.append(X)
    X = X.astype('float64')

    # 1-based member ids. The dtype is the smallest unsigned type that can
    # hold NPar (uint8 up to 255 members), so ids no longer wrap around for
    # larger ensembles as they did with a fixed uint8 cast.
    idf = np.arange(1, NPar + 1, dtype=np.min_scalar_type(NPar))

    main_dir = os.getcwd()
    y11, y2 = tough_py.extract(main_dir, X, idf, k)
    y11 = y11.astype('float64')
    y1 = y11 * enlarge          # same scaling as the kernel observations O1
    y2 = y2.astype('float64')
    savebioty1.append(y11)      # stored unscaled
    savey.append(y2)

    # Perturbed observations, one independent realisation per member.
    # Tracer noise is additive; kernel noise is added to log10(O1) and the
    # result is transformed back, i.e. it is multiplicative in linear space.
    obs_tracer_error = np.tile(O2, (NPar, 1)) + (inflate**0.5) * np.random.normal(
        loc=0, scale=(C_e**0.5), size=(NPar, tracerpoints))
    obs_kernel_error = np.tile(np.log10(O1), (NPar, 1)) + (inflate**0.5) * np.random.normal(
        loc=0, scale=(C_e**0.5), size=(NPar, fracturepoints))
    obs_kernel_error = 10**obs_kernel_error

    Oand = np.concatenate((obs_tracer_error, obs_kernel_error), axis=1)
    Obs = Oand.astype('float64')

    # Simulated data in the same column order as Obs: tracer first, kernel second.
    y = np.concatenate((y2, y1), axis=1)
    Bias = y - Obs

    # Convergence measure: mean absolute mismatch over all members and data.
    vhat = np.mean(np.abs(Bias))
    Tra.append(vhat)

    # Ensemble covariances (columns are variables, rows are members).
    Dxy = cross_cov(x, y)
    Dyy = np.cov(y, rowvar=False)   # auto-covariance of y, (NMea, NMea)

    # Gain and damped update (only half of the computed step is applied).
    Gt = np.dot(Dxy, np.linalg.pinv(Dyy + Rt))
    add = np.transpose(np.dot(Gt, np.transpose(Bias)))
    x = x - 0.5 * add

    k = k + 1
    np.savetxt(r'./Mengdi/storage/x{}.txt'.format(k), x)
    print('k = {}'.format(k))
    
#%%save results
# Stack the per-iteration arrays row-wise so every result is one 2-D table.
saveX1 = np.vstack(saveX)
print(saveX1.shape)
savey1 = np.vstack(savey)
print(savey1.shape)
Tra1 = np.vstack(Tra)
print(Tra1.shape)
SaveBiot = np.vstack(savebioty1)

# Write every table with the same number format.
for _target, _table in (
    (r'./Mengdi/storage/saveX1.txt', saveX1),
    (r'./Mengdi/storage/savey1.txt', savey1),
    (r'./Mengdi/storage/Tra1.txt', Tra1),
    (r'./Mengdi/storage/SaveBiot.txt', SaveBiot),
):
    np.savetxt(_target, _table, fmt='%.4e')

end_time = time.time()
elapsed_time = end_time - start_time

# Break the elapsed seconds into whole hours, minutes and seconds.
_whole_hours, _remainder = divmod(elapsed_time, 3600)
_whole_minutes, _whole_seconds = divmod(_remainder, 60)
hours = int(_whole_hours)
minutes = int(_whole_minutes)
seconds = int(_whole_seconds)

elapsed_text = f"elapsed time: {hours}h {minutes}min {seconds}s"
with open("./Mengdi/storage/execution_time.txt", "w", encoding="utf-8") as file:
    file.write(elapsed_text)
    
    
#%%plot
# Simulated tracer curves of every stored forward run (one row per run).
tracer_concentration = savey1
observe_micro = np.loadtxt('kernel38.txt')
observe_tracer = np.loadtxt('tracer38.txt')

# x axis: 1..number of tracer samples. Named `days` so the `time` module is
# not shadowed, and sized from the data instead of a hard-coded 62.
days = np.arange(1, tracer_concentration.shape[1] + 1)

# The runs of the last two iterations (2 * NPar rows) are drawn darker than
# the earlier ones. The split is taken from the rows actually stored, so it
# stays valid when the loop stopped before iter_num iterations.
n_runs = tracer_concentration.shape[0]
recent_start = max(n_runs - 2 * NPar, 0)
if recent_start > 0:
    # plot() draws one line per column, hence the transpose.
    plt.plot(days, tracer_concentration[:recent_start].T, color='#9a9fa2')
if n_runs > recent_start:
    plt.plot(days, tracer_concentration[recent_start:].T, color='#676c6e')

# Observed curve on top of the simulated ones.
plt.scatter(days, observe_tracer, color='#e42313', zorder=2)

plt.xlabel('Time (day)')
plt.ylabel('C (mg/L)')
plt.title('Inversion 1 - Breakthrough curves of IES in sythetic case')

plt.savefig(r'./Mengdi/storage/Breakthrough.png', dpi=600)
plt.close()

def calculate_mse(y_pred, y_true):
    """Return the mean squared error between ``y_pred`` and ``y_true``.

    Both arguments are converted to float arrays; the usual NumPy
    broadcasting rules apply to the subtraction. The mean is taken over all
    elements, so the result is a single number.
    """
    truth = np.asarray(y_true).astype(float)
    prediction = np.asarray(y_pred).astype(float)
    return np.mean(np.square(truth - prediction))

# One x position per stored forward run. Sized from the data rather than from
# iter_num so x and y always have the same length, also after an early stop.
n_runs = tracer_concentration.shape[0]
iter_times = np.arange(1, n_runs + 1)

# Simulated kernel data as written to disk, brought to the scaled space used
# during the inversion. ndmin=2 keeps a single-row file two-dimensional.
micro = np.loadtxt(r'./Mengdi/storage/SaveBiot.txt', ndmin=2) * enlarge
# The observed kernel must carry the same `enlarge` factor; without it the
# scaled simulations would be compared against unscaled observations.
observe_micro_scaled = observe_micro * enlarge

# Per-run mean squared error for each data type.
mse_tracerall = np.array(
    [calculate_mse(run, observe_tracer) for run in tracer_concentration])
mse_microall = np.array(
    [calculate_mse(run, observe_micro_scaled) for run in micro])
mse_all = mse_microall + mse_tracerall
rmse_all = np.sqrt(mse_all)
#%%RMSE

plt.figure(figsize=(8, 12))
# Figure-level title; plt.title here would label only the last subplot.
plt.suptitle('Inversion 1 - RMSE of IES in sythetic case')
# NOTE(review): the x values are run indices (members x iterations), although
# the axis label says 'iteration'.
for _position, (_values, _label) in enumerate(
        ((rmse_all, 'RMSE'), (mse_tracerall, 'MSEtracer'), (mse_microall, 'MSEmicro')),
        start=1):
    plt.subplot(3, 1, _position)
    plt.scatter(iter_times, _values, color='#00457c')
    plt.xlabel('iteration')
    plt.ylabel(_label)

plt.subplots_adjust(hspace=0.4)
plt.savefig(r'./Mengdi/storage/RMSE.png', dpi=600)
plt.close()